{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('max_rows', 500)\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 5)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>BIO_anno</th>\n",
       "      <th>class</th>\n",
       "      <th>bank_topic</th>\n",
       "      <th>id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>交行14年用过，半年准备提额，却直接被降到1Ｋ，半年期间只T过一次三千，其它全部真实消费，第...</td>\n",
       "      <td>B-BANK I-BANK O O O O O O O O O O B-COMMENTS_N...</td>\n",
       "      <td>0</td>\n",
       "      <td>建设银行</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>单标我有了，最近visa双标返现活动好</td>\n",
       "      <td>B-PRODUCT I-PRODUCT O O O O O O B-PRODUCT I-PR...</td>\n",
       "      <td>1</td>\n",
       "      <td>建设银行</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>建设银行提额很慢的……</td>\n",
       "      <td>B-BANK I-BANK I-BANK I-BANK B-COMMENTS_N I-COM...</td>\n",
       "      <td>0</td>\n",
       "      <td>建设银行</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>孙女士在原恒泰农村合作银行存入50万元，同年9月又存款50万元。2014年，恒泰农村合作银行...</td>\n",
       "      <td>O O O O O B-BANK I-BANK I-BANK I-BANK I-BANK I...</td>\n",
       "      <td>2</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>我的怎么显示0.25费率，而且不管分多少期都一样费率，可惜只有69k</td>\n",
       "      <td>O O O O O O O O O O B-COMMENTS_N I-COMMENTS_N ...</td>\n",
       "      <td>2</td>\n",
       "      <td>建设银行</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                text  \\\n",
       "0  交行14年用过，半年准备提额，却直接被降到1Ｋ，半年期间只T过一次三千，其它全部真实消费，第...   \n",
       "1                                单标我有了，最近visa双标返现活动好   \n",
       "2                                        建设银行提额很慢的……   \n",
       "3  孙女士在原恒泰农村合作银行存入50万元，同年9月又存款50万元。2014年，恒泰农村合作银行...   \n",
       "4                 我的怎么显示0.25费率，而且不管分多少期都一样费率，可惜只有69k   \n",
       "\n",
       "                                            BIO_anno  class bank_topic  id  \n",
       "0  B-BANK I-BANK O O O O O O O O O O B-COMMENTS_N...      0       建设银行   0  \n",
       "1  B-PRODUCT I-PRODUCT O O O O O O B-PRODUCT I-PR...      1       建设银行   1  \n",
       "2  B-BANK I-BANK I-BANK I-BANK B-COMMENTS_N I-COM...      0       建设银行   2  \n",
       "3  O O O O O B-BANK I-BANK I-BANK I-BANK I-BANK I...      2        NaN   3  \n",
       "4  O O O O O O O O O O B-COMMENTS_N I-COMMENTS_N ...      2       建设银行   4  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('raw_data/train_data_public.csv', usecols=['text', 'BIO_anno', 'class', 'bank_topic'])\n",
    "train['id'] = [i for i in range(len(train))]\n",
    "\n",
    "print(train.shape)\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5093, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>一个月有不到40笔吗？卡多了也是累赘。再看就不能打电话交保护费保险了。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>光大│福卡13K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>对呀，才四万，又不是很多</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>为啥我10万每个月800啊，这个额度分分钟</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>3. 债券，一般有多余的钱可以考虑的，还是不错的投资</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                 text\n",
       "0   0  一个月有不到40笔吗？卡多了也是累赘。再看就不能打电话交保护费保险了。\n",
       "1   1                             光大│福卡13K\n",
       "2   2                         对呀，才四万，又不是很多\n",
       "3   3                为啥我10万每个月800啊，这个额度分分钟\n",
       "4   4           3. 债券，一般有多余的钱可以考虑的，还是不错的投资"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = pd.read_csv('raw_data/test_public.csv')\n",
    "\n",
    "print(test.shape)\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5093, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>BIO_anno</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>O O O O O O O O O O O O O O O O O O O O O O O ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>O O O O O O O O</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>O O O O O O O O O O O O</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>O O O O O O O O O O O O O O O O O O O O O</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>O O O O O O O O O O O O O O O O O O O O O O O ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                           BIO_anno  class\n",
       "0   0  O O O O O O O O O O O O O O O O O O O O O O O ...      2\n",
       "1   1                                    O O O O O O O O      2\n",
       "2   2                            O O O O O O O O O O O O      2\n",
       "3   3          O O O O O O O O O O O O O O O O O O O O O      2\n",
       "4   4  O O O O O O O O O O O O O O O O O O O O O O O ...      2"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submit = pd.read_csv('raw_data/submit_example.csv')\n",
    "\n",
    "print(submit.shape)\n",
    "submit.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simpletransformers.ner import NERModel, NERArgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f4c0daeb73c4a078649481dc490be25",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "(210429, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence_id</th>\n",
       "      <th>words</th>\n",
       "      <th>labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>交</td>\n",
       "      <td>B-BANK</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>行</td>\n",
       "      <td>I-BANK</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>年</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>用</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>过</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0</td>\n",
       "      <td>半</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0</td>\n",
       "      <td>年</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0</td>\n",
       "      <td>准</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0</td>\n",
       "      <td>备</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0</td>\n",
       "      <td>提</td>\n",
       "      <td>B-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0</td>\n",
       "      <td>额</td>\n",
       "      <td>I-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0</td>\n",
       "      <td>却</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0</td>\n",
       "      <td>直</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0</td>\n",
       "      <td>接</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0</td>\n",
       "      <td>被</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0</td>\n",
       "      <td>降</td>\n",
       "      <td>B-COMMENTS_ADJ</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0</td>\n",
       "      <td>到</td>\n",
       "      <td>I-COMMENTS_ADJ</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0</td>\n",
       "      <td>Ｋ</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0</td>\n",
       "      <td>半</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0</td>\n",
       "      <td>年</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0</td>\n",
       "      <td>期</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0</td>\n",
       "      <td>间</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0</td>\n",
       "      <td>只</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0</td>\n",
       "      <td>T</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0</td>\n",
       "      <td>过</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>0</td>\n",
       "      <td>一</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>0</td>\n",
       "      <td>次</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>0</td>\n",
       "      <td>三</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>0</td>\n",
       "      <td>千</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>0</td>\n",
       "      <td>其</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>0</td>\n",
       "      <td>它</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>0</td>\n",
       "      <td>全</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>0</td>\n",
       "      <td>部</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>0</td>\n",
       "      <td>真</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>0</td>\n",
       "      <td>实</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>0</td>\n",
       "      <td>消</td>\n",
       "      <td>B-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>0</td>\n",
       "      <td>费</td>\n",
       "      <td>I-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>0</td>\n",
       "      <td>第</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>0</td>\n",
       "      <td>六</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>0</td>\n",
       "      <td>个</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>0</td>\n",
       "      <td>月</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>0</td>\n",
       "      <td>的</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>0</td>\n",
       "      <td>时</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>0</td>\n",
       "      <td>候</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>52</th>\n",
       "      <td>0</td>\n",
       "      <td>为</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>53</th>\n",
       "      <td>0</td>\n",
       "      <td>了</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54</th>\n",
       "      <td>0</td>\n",
       "      <td>增</td>\n",
       "      <td>B-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>55</th>\n",
       "      <td>0</td>\n",
       "      <td>加</td>\n",
       "      <td>I-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56</th>\n",
       "      <td>0</td>\n",
       "      <td>评</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>57</th>\n",
       "      <td>0</td>\n",
       "      <td>分</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>58</th>\n",
       "      <td>0</td>\n",
       "      <td>提</td>\n",
       "      <td>B-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>0</td>\n",
       "      <td>额</td>\n",
       "      <td>I-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>60</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>0</td>\n",
       "      <td>还</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>62</th>\n",
       "      <td>0</td>\n",
       "      <td>特</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>63</th>\n",
       "      <td>0</td>\n",
       "      <td>意</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>0</td>\n",
       "      <td>分</td>\n",
       "      <td>B-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>0</td>\n",
       "      <td>期</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>0</td>\n",
       "      <td>两</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>67</th>\n",
       "      <td>0</td>\n",
       "      <td>万</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>69</th>\n",
       "      <td>0</td>\n",
       "      <td>但</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>0</td>\n",
       "      <td>降</td>\n",
       "      <td>B-COMMENTS_ADJ</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>71</th>\n",
       "      <td>0</td>\n",
       "      <td>额</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>0</td>\n",
       "      <td>后</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>0</td>\n",
       "      <td>电</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>74</th>\n",
       "      <td>0</td>\n",
       "      <td>话</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>0</td>\n",
       "      <td>投</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>76</th>\n",
       "      <td>0</td>\n",
       "      <td>诉</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>0</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>0</td>\n",
       "      <td>申</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>0</td>\n",
       "      <td>请</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80</th>\n",
       "      <td>0</td>\n",
       "      <td>提</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>81</th>\n",
       "      <td>0</td>\n",
       "      <td>.</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82</th>\n",
       "      <td>0</td>\n",
       "      <td>.</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83</th>\n",
       "      <td>0</td>\n",
       "      <td>.</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>84</th>\n",
       "      <td>1</td>\n",
       "      <td>单</td>\n",
       "      <td>B-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>85</th>\n",
       "      <td>1</td>\n",
       "      <td>标</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>1</td>\n",
       "      <td>我</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>87</th>\n",
       "      <td>1</td>\n",
       "      <td>有</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>1</td>\n",
       "      <td>了</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>1</td>\n",
       "      <td>，</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>1</td>\n",
       "      <td>最</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>91</th>\n",
       "      <td>1</td>\n",
       "      <td>近</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>92</th>\n",
       "      <td>1</td>\n",
       "      <td>v</td>\n",
       "      <td>B-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>93</th>\n",
       "      <td>1</td>\n",
       "      <td>i</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>94</th>\n",
       "      <td>1</td>\n",
       "      <td>s</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>1</td>\n",
       "      <td>a</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>1</td>\n",
       "      <td>双</td>\n",
       "      <td>B-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>1</td>\n",
       "      <td>标</td>\n",
       "      <td>I-PRODUCT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>1</td>\n",
       "      <td>返</td>\n",
       "      <td>B-COMMENTS_N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>1</td>\n",
       "      <td>现</td>\n",
       "      <td>I-COMMENTS_N</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    sentence_id words          labels\n",
       "0             0     交          B-BANK\n",
       "1             0     行          I-BANK\n",
       "2             0     1               O\n",
       "3             0     4               O\n",
       "4             0     年               O\n",
       "5             0     用               O\n",
       "6             0     过               O\n",
       "7             0     ，               O\n",
       "8             0     半               O\n",
       "9             0     年               O\n",
       "10            0     准               O\n",
       "11            0     备               O\n",
       "12            0     提    B-COMMENTS_N\n",
       "13            0     额    I-COMMENTS_N\n",
       "14            0     ，               O\n",
       "15            0     却               O\n",
       "16            0     直               O\n",
       "17            0     接               O\n",
       "18            0     被               O\n",
       "19            0     降  B-COMMENTS_ADJ\n",
       "20            0     到  I-COMMENTS_ADJ\n",
       "21            0     1               O\n",
       "22            0     Ｋ               O\n",
       "23            0     ，               O\n",
       "24            0     半               O\n",
       "25            0     年               O\n",
       "26            0     期               O\n",
       "27            0     间               O\n",
       "28            0     只               O\n",
       "29            0     T               O\n",
       "30            0     过               O\n",
       "31            0     一               O\n",
       "32            0     次               O\n",
       "33            0     三               O\n",
       "34            0     千               O\n",
       "35            0     ，               O\n",
       "36            0     其               O\n",
       "37            0     它               O\n",
       "38            0     全               O\n",
       "39            0     部               O\n",
       "40            0     真               O\n",
       "41            0     实               O\n",
       "42            0     消    B-COMMENTS_N\n",
       "43            0     费    I-COMMENTS_N\n",
       "44            0     ，               O\n",
       "45            0     第               O\n",
       "46            0     六               O\n",
       "47            0     个               O\n",
       "48            0     月               O\n",
       "49            0     的               O\n",
       "50            0     时               O\n",
       "51            0     候               O\n",
       "52            0     为               O\n",
       "53            0     了               O\n",
       "54            0     增    B-COMMENTS_N\n",
       "55            0     加    I-COMMENTS_N\n",
       "56            0     评               O\n",
       "57            0     分               O\n",
       "58            0     提    B-COMMENTS_N\n",
       "59            0     额    I-COMMENTS_N\n",
       "60            0     ，               O\n",
       "61            0     还               O\n",
       "62            0     特               O\n",
       "63            0     意               O\n",
       "64            0     分       B-PRODUCT\n",
       "65            0     期       I-PRODUCT\n",
       "66            0     两               O\n",
       "67            0     万               O\n",
       "68            0     ，               O\n",
       "69            0     但               O\n",
       "70            0     降  B-COMMENTS_ADJ\n",
       "71            0     额               O\n",
       "72            0     后               O\n",
       "73            0     电               O\n",
       "74            0     话               O\n",
       "75            0     投               O\n",
       "76            0     诉               O\n",
       "77            0     ，               O\n",
       "78            0     申               O\n",
       "79            0     请               O\n",
       "80            0     提               O\n",
       "81            0     .               O\n",
       "82            0     .               O\n",
       "83            0     .               O\n",
       "84            1     单       B-PRODUCT\n",
       "85            1     标       I-PRODUCT\n",
       "86            1     我               O\n",
       "87            1     有               O\n",
       "88            1     了               O\n",
       "89            1     ，               O\n",
       "90            1     最               O\n",
       "91            1     近               O\n",
       "92            1     v       B-PRODUCT\n",
       "93            1     i       I-PRODUCT\n",
       "94            1     s       I-PRODUCT\n",
       "95            1     a       I-PRODUCT\n",
       "96            1     双       B-PRODUCT\n",
       "97            1     标       I-PRODUCT\n",
       "98            1     返    B-COMMENTS_N\n",
       "99            1     现    I-COMMENTS_N"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flatten the sentence-level annotations into one row per character\n",
    "# (sentence_id / words / labels), the frame later handed to train_model.\n",
    "train_data = list()\n",
    "\n",
    "for _, row in tqdm(train.iterrows(), total=len(train)):\n",
    "    sentence_id = row['id']\n",
    "    sentence = row['text']\n",
    "    bios = row['BIO_anno'].split()\n",
    "    # A row with fewer tags than characters cannot be aligned; report which\n",
    "    # sentence is broken instead of failing with a bare IndexError.\n",
    "    if len(bios) < len(sentence):\n",
    "        raise ValueError(\n",
    "            f'sentence {sentence_id}: {len(sentence)} characters '\n",
    "            f'but only {len(bios)} BIO tags'\n",
    "        )\n",
    "    # zip stops at the end of the sentence, so surplus tags are ignored.\n",
    "    for char, bio in zip(sentence, bios):\n",
    "        train_data.append({\n",
    "            'sentence_id': sentence_id,\n",
    "            'words': char,\n",
    "            'labels': bio,\n",
    "        })\n",
    "\n",
    "train_data = pd.DataFrame(train_data)\n",
    "\n",
    "print(train_data.shape)\n",
    "train_data.head(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag inventory handed to NERModel: a B-/I- pair for each of the four\n",
    "# entity types that appear in BIO_anno, plus the outside tag 'O'.\n",
    "labels = [\n",
    "    'B-BANK',\n",
    "    'I-BANK',\n",
    "    'B-PRODUCT',\n",
    "    'I-PRODUCT',\n",
    "    'B-COMMENTS_N',\n",
    "    'I-COMMENTS_N',\n",
    "    'B-COMMENTS_ADJ',\n",
    "    'I-COMMENTS_ADJ',\n",
    "    'O'\n",
    "]\n",
    "\n",
    "model_args = NERArgs()\n",
    "model_args.max_seq_length = 400\n",
    "model_args.train_batch_size = 8\n",
    "model_args.num_train_epochs = 3\n",
    "model_args.fp16 = False\n",
    "model_args.evaluate_during_training = True\n",
    "# Fix the RNG seed so that repeated runs of this notebook are comparable.\n",
    "model_args.manual_seed = 42\n",
    "# Keep the NER checkpoints in their own directory: the classifier trained\n",
    "# further down would otherwise find a non-empty default output directory\n",
    "# and refuse to train. Overwriting lets the notebook be re-run end to end.\n",
    "model_args.output_dir = 'outputs_ner/'\n",
    "model_args.best_model_dir = 'outputs_ner/best_model'\n",
    "model_args.overwrite_output_dir = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-bert-wwm-ext were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at hfl/chinese-bert-wwm-ext and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Token-classification model initialised from the hfl/chinese-bert-wwm-ext\n",
    "# checkpoint, predicting the tags listed in `labels`.\n",
    "model = NERModel(\n",
    "    model_type=\"bert\",\n",
    "    model_name=\"hfl/chinese-bert-wwm-ext\",\n",
    "    labels=labels,\n",
    "    args=model_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "798b2baf952e42d0901f01f9580655f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b8f4873d4ae4e7ca05c7454a4189115",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42c32403858f41aaa189f0a12c2e6aba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=1125.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32b180f32f194cf9b8772bd05f4a3c3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dcd5e241bfaf4e939be75c808fc008ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=125.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be348012270c44a5a436bd2a563ad897",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=1125.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0712537fc9846bb94d80b8cdc2ca9f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42166f6fecbe451b9b8bc1a9770f7696",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=125.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "940a70c8299a4d9484bee984fd0a2cee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6eb4b96d843f40219b69c97d995a27d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=125.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a02b1b8b1f943a8b897757c87a134d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=1125.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17e3b59085794900ba3a04a2a5c8385e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4544d6761d1e4363a6a566c50173699e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=125.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(3375,\n",
       " {'global_step': [1125, 2000, 2250, 3375],\n",
       "  'precision': [0.848078096400244,\n",
       "   0.8239067055393586,\n",
       "   0.8167808219178082,\n",
       "   0.8505002942907592],\n",
       "  'recall': [0.8671241422333126,\n",
       "   0.8814722395508422,\n",
       "   0.8927011852776044,\n",
       "   0.9014348097317529],\n",
       "  'f1_score': [0.8574953732264035,\n",
       "   0.8517179023508137,\n",
       "   0.8530551415797317,\n",
       "   0.8752271350696547],\n",
       "  'train_loss': [0.1103605329990387,\n",
       "   0.2675149142742157,\n",
       "   0.09671870619058609,\n",
       "   0.05714509263634682],\n",
       "  'eval_loss': [0.10910721370577812,\n",
       "   0.10822772925719619,\n",
       "   0.10991965884529054,\n",
       "   0.11606662555783986]})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sentences with id below 9000 are used for fitting; the remaining ones\n",
    "# are passed as eval_data for the evaluation runs during training.\n",
    "sentence_ids = train_data['sentence_id']\n",
    "train_df = train_data.loc[sentence_ids.lt(9000)]\n",
    "valid_df = train_data.loc[sentence_ids.ge(9000)]\n",
    "\n",
    "model.train_model(train_df, eval_data=valid_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ed112f814b74e22a37182932bf2f0fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ae2fe25d8ee47d89b318385828016b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Prediction', max=637.0, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# predict() is documented to take a Python list, so materialise the column\n",
    "# first (the classifier cell below does the same). With split_on_space=False\n",
    "# each item is consumed as an already-split sequence; iterating a str yields\n",
    "# its characters, which matches the per-character training rows.\n",
    "pred, _ = model.predict(list(test['text']), split_on_space=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9a76956214c4f2bb57a94621ae8c7dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>BIO_anno</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>O O O O O O O O O O O O O O O O B-COMMENTS_ADJ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>B-BANK I-BANK O B-PRODUCT I-PRODUCT O O O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>O O O O O O O O O O B-COMMENTS_ADJ I-COMMENTS_ADJ</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>O O O O O O O O O O O O O O O O B-COMMENTS_N I...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>O O B-PRODUCT I-PRODUCT O O O O B-COMMENTS_ADJ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                           BIO_anno\n",
       "0   0  O O O O O O O O O O O O O O O O B-COMMENTS_ADJ...\n",
       "1   1          B-BANK I-BANK O B-PRODUCT I-PRODUCT O O O\n",
       "2   2  O O O O O O O O O O B-COMMENTS_ADJ I-COMMENTS_ADJ\n",
       "3   3  O O O O O O O O O O O O O O O O B-COMMENTS_N I...\n",
       "4   4  O O B-PRODUCT I-PRODUCT O O O O B-COMMENTS_ADJ..."
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn the per-token predictions into one space-separated tag string per text.\n",
    "sub_data = list()\n",
    "\n",
    "# `pred` is positional (one entry per text given to predict), so rows are\n",
    "# matched to it by position; the index label is only used as the output id.\n",
    "for pos, (idx, row) in enumerate(tqdm(test.iterrows(), total=len(test))):\n",
    "    # Each prediction is a dict; its first value is the predicted tag.\n",
    "    tags = [next(iter(token_pred.values())) for token_pred in pred[pos]]\n",
    "    # A text longer than max_seq_length comes back with fewer tags than\n",
    "    # characters; pad the tail with 'O' so every character keeps a tag.\n",
    "    # NOTE(review): assumes missing tags are always at the end - confirm.\n",
    "    tags.extend(['O'] * (len(row['text']) - len(tags)))\n",
    "    sub_data.append({'id': idx, 'BIO_anno': ' '.join(tags)})\n",
    "\n",
    "sub = pd.DataFrame(sub_data)\n",
    "sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the NER predictions; the classifier section reads this file back\n",
    "# to attach the predicted class column.\n",
    "sub.to_csv('ner_result.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simpletransformers.classification import ClassificationModel, ClassificationArgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2    9620\n",
       "0     242\n",
       "1     138\n",
       "Name: class, dtype: int64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the training rows per value of the `class` column; dropna=False\n",
    "# keeps missing labels as their own bucket.\n",
    "train['class'].value_counts(dropna=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration for the sentence-level classifier.\n",
    "model_args = ClassificationArgs()\n",
    "model_args.max_seq_length = 400\n",
    "model_args.train_batch_size = 8\n",
    "model_args.num_train_epochs = 3\n",
    "model_args.fp16 = False\n",
    "# Fix the RNG seed so that repeated runs of this notebook are comparable.\n",
    "model_args.manual_seed = 42\n",
    "# Write checkpoints to a directory of their own and allow overwriting it:\n",
    "# train_model refuses to start when its output directory already holds\n",
    "# files, e.g. those left behind by the NER run or by a previous execution.\n",
    "model_args.output_dir = 'outputs_cls/'\n",
    "model_args.overwrite_output_dir = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-bert-wwm-ext were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/chinese-bert-wwm-ext and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Three-way sequence classifier initialised from the same\n",
    "# hfl/chinese-bert-wwm-ext checkpoint; this rebinds `model`.\n",
    "model = ClassificationModel(\n",
    "    model_type=\"bert\",\n",
    "    model_name=\"hfl/chinese-bert-wwm-ext\",\n",
    "    num_labels=3,\n",
    "    args=model_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-column training frame for the classifier: the raw text and its\n",
    "# class, with the class column renamed to `labels`.\n",
    "train_df = train[['text', 'class']].rename(columns={'class': 'labels'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f597235cd92247598e783a746176d9f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d34dae662144501926a364efc5f93f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37165c0e121441089bb7c852abd1497a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 3', max=1250.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "980c6c172be9486d9718abcc4577f016",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 3', max=1250.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b5a14f80d9049439874403699e0784f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 3', max=1250.0, style=ProgressStyle(de…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(3750, 0.17493137988795837)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fine-tune the classifier on every training row; no eval_data is passed,\n",
    "# so nothing is held out for validation here.\n",
    "model.train_model(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c57aa5f77f7b4a1ebe3edfafcb62ee34",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5093.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bebe052f104d4ff1808ea7fd1090ab19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Predicted class for each test text; the second return value is discarded.\n",
    "pred, _ = model.predict(test['text'].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>BIO_anno</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>O O O O O O O O O O O O O O O O B-COMMENTS_ADJ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>B-BANK I-BANK O B-PRODUCT I-PRODUCT O O O</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>O O O O O O O O O O B-COMMENTS_ADJ I-COMMENTS_ADJ</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>O O O O O O O O O O O O O O O O B-COMMENTS_N I...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>O O B-PRODUCT I-PRODUCT O O O O B-COMMENTS_ADJ...</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                           BIO_anno  class\n",
       "0   0  O O O O O O O O O O O O O O O O B-COMMENTS_ADJ...      2\n",
       "1   1          B-BANK I-BANK O B-PRODUCT I-PRODUCT O O O      2\n",
       "2   2  O O O O O O O O O O B-COMMENTS_ADJ I-COMMENTS_ADJ      2\n",
       "3   3  O O O O O O O O O O O O O O O O B-COMMENTS_N I...      2\n",
       "4   4  O O B-PRODUCT I-PRODUCT O O O O B-COMMENTS_ADJ...      2"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the saved NER predictions and attach the predicted class to them.\n",
    "# `class` is a Python keyword, hence the dict-unpacking form of assign.\n",
    "sub = pd.read_csv('ner_result.csv').assign(**{'class': pred})\n",
    "sub.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2    4998\n",
       "0      64\n",
       "1      31\n",
       "Name: class, dtype: int64"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distribution of the predicted classes over the test set, to compare\n",
    "# against the training distribution shown earlier.\n",
    "sub['class'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the combined result (id, BIO_anno, class) without the index column.\n",
    "sub.to_csv('baseline.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
